{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QTJt9pwUTbHo"
   },
   "source": [
    "# Intelligent Synthetic Dataset Generator\n",
    "\n",
    "An AI-powered tool that creates realistic synthetic datasets for any business case—whether you provide the schema or let it intelligently design one for you.\n",
    "\n",
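    "It works with Claude, Gemini and GPT through their APIs, and with open-source Hugging Face models (Llama 3.1, Phi-3 mini, Gemma 2) that run locally on the GPU."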
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "l_FljmlTUoka"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "aONqZ-SjUJdg",
    "outputId": "1f5c7b2e-95f0-4f23-cf01-2bd5bda0807a"
   },
   "outputs": [],
   "source": [
    "!pip install -q requests bitsandbytes anthropic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ub1unBFvTatE"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "import json\n",
    "from google.colab import userdata\n",
    "\n",
    "from openai import OpenAI\n",
    "import anthropic\n",
    "from huggingface_hub import login\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n",
    "import torch\n",
    "import pandas as pd\n",
    "\n",
    "import gradio as gr\n",
    "import gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "viZNPtObUOcz"
   },
   "outputs": [],
   "source": [
    "hf_token = userdata.get('HF_TOKEN')\n",
    "openai_api_key = userdata.get('OPENAI_API_KEY')\n",
    "anthropic_api_key = userdata.get('ANTHROPIC_API_KEY')\n",
    "google_api_key = userdata.get('GOOGLE_API_KEY')\n",
    "\n",
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9Q94S6JTUWn5"
   },
   "outputs": [],
   "source": [
    "quant_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    bnb_4bit_quant_type=\"nf4\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mrjdVEpaUxHz"
   },
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LvNE6foEUPaz"
   },
   "outputs": [],
   "source": [
    "LLAMA = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "PHI3 = \"microsoft/Phi-3-mini-4k-instruct\"\n",
    "GEMMA2 = \"google/gemma-2-2b-it\"\n",
    "GPT = \"gpt-4o-mini\"\n",
    "CLAUDE = \"claude-3-haiku-20240307\"\n",
    "GEMINI = \"gemini-2.0-flash\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tvafTFD8XmaO"
   },
   "outputs": [],
   "source": [
    "MODELS = {\n",
    "    'LLama 3.1' : LLAMA,\n",
    "    'Phi 3 mini': PHI3,\n",
    "    'Gemma 2': GEMMA2,\n",
    "    'GPT 4.o mini': GPT,\n",
    "    'Claude 3 Haiku': CLAUDE,\n",
    "    'Gemini 2.0 Flash': GEMINI,\n",
    "}\n",
    "\n",
    "HF_MODELS = [LLAMA, PHI3, GEMMA2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2LZqA9QXXl0t"
   },
   "outputs": [],
   "source": [
    "FILE_FORMATS = [\".csv\", \".tsv\", \".jsonl\", \".json\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "d6EnN7SVXhza",
    "outputId": "55f6ac4d-adeb-4216-b2a8-d67524b005d3"
   },
   "outputs": [],
   "source": [
    "SCHEMA = [\n",
    "    (\"Name\", \"TEXT\", \"Name of the restaurant\", \"Blue River Bistro\"),\n",
    "    (\"Address\", \"TEXT\", \"Restaurant address\", \"742 Evergreen Terrace, Springfield, IL 62704\"),\n",
    "    (\"Type\", \"TEXT\", \"Kitchen type\", 'One of [\"Thai\",\"Mediterranean\",\"Vegan\",\"Steakhouse\",\"Japanese\"] or other potential types'),\n",
    "    (\"Average Price\", \"TEXT\", \"Average meal price\", \"$45, or '--' if unknown\"),\n",
    "    (\"Year\", \"INT\", \"Year of restaurant opening\", 2015),\n",
    "    (\"Menu\", \"Array\", \"List of meals\", '[\"Grilled Salmon\", \"Caesar Salad\", \"Pad Thai\", \"Margherita Pizza\", ...]'),\n",
    "]\n",
    "\n",
    "DEFAULT_SCHEMA_TEXT = \"\\n\".join([f\"{i+1}. {col[0]} ({col[1]}) - {col[2]}, example: {col[3]}\" for i, col in enumerate(SCHEMA)])\n",
    "print(DEFAULT_SCHEMA_TEXT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "W-46TDTOXiS7"
   },
   "outputs": [],
   "source": [
    "system_prompt = \"\"\"\n",
    "You are an expert in generating synthetic datasets tailored to a given business case and user requirements.\n",
    "If the user does not specify output columns, infer and create the most appropriate columns based on your expertise.\n",
    "Do NOT repeat column values from one row to another. Only output valid JSONL without any comments.\"\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def get_user_prompt(business_case, schema_text, nr_records):\n",
    "    prompt = f\"The business case is: {business_case}.\\nGenerate {nr_records} rows of data in JSONL format.\\n\"\n",
    "\n",
    "    if schema_text is not None:\n",
    "      prompt += f\"Each line should be a JSON object with the following fields: \\n{schema_text}\\n\"\n",
    "\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gPf1GcAwhwa_"
   },
   "source": [
    "## LLM handler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Tf-WEQUKhY-z"
   },
   "outputs": [],
   "source": [
    "def ask_gpt(model: str, user_prompt: str):\n",
    "  client = OpenAI(api_key=openai_api_key)\n",
    "  messages = [\n",
    "      {\"role\": \"system\", \"content\": system_prompt},\n",
    "      {\"role\": \"user\", \"content\": user_prompt}\n",
    "    ]\n",
    "  response = client.chat.completions.create(\n",
    "      model=model,\n",
    "      messages=messages,\n",
    "      temperature=0.7\n",
    "  )\n",
    "  content = response.choices[0].message.content\n",
    "\n",
    "  return content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "856pnIHahzDd"
   },
   "outputs": [],
   "source": [
    "def ask_claude(model: str, user_prompt: str):\n",
    "  client = anthropic.Anthropic(api_key=anthropic_api_key)\n",
    "  response = client.messages.create(\n",
    "      model=model,\n",
    "      messages=[{\"role\": \"user\", \"content\": user_prompt}],\n",
    "      max_tokens=4000,\n",
    "      temperature=0.7,\n",
    "      system=system_prompt\n",
    "  )\n",
    "  content = response.content[0].text\n",
    "\n",
    "  return content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p0AfSbcBiUlg"
   },
   "outputs": [],
   "source": [
    "def ask_gemini(model: str, user_prompt: str):\n",
    "  client = OpenAI(\n",
    "      api_key=google_api_key,\n",
    "      base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
    "  )\n",
    "  messages = [\n",
    "      {\"role\": \"system\", \"content\": system_prompt},\n",
    "      {\"role\": \"user\", \"content\": user_prompt}\n",
    "    ]\n",
    "  response = client.chat.completions.create(\n",
    "      model=model,\n",
    "      messages=messages,\n",
    "      temperature=0.7\n",
    "  )\n",
    "  content = response.choices[0].message.content\n",
    "\n",
    "  return content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K9LZZPJ9irrH"
   },
   "outputs": [],
   "source": [
    "def ask_hf(model: str, user_prompt: str):\n",
    "  global tokenizer, inputs, hf_model, outputs\n",
    "\n",
    "  messages = [\n",
    "        {\"role\": \"system\", \"content\": system_prompt},\n",
    "        {\"role\": \"user\", \"content\": user_prompt}\n",
    "      ]\n",
    "\n",
    "  tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)\n",
    "  tokenizer.pad_token = tokenizer.eos_token\n",
    "  inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\").to(\"cuda\")\n",
    "  if hf_model == None:\n",
    "      hf_model = AutoModelForCausalLM.from_pretrained(model, device_map=\"auto\", quantization_config=quant_config)\n",
    "  outputs = hf_model.generate(inputs, max_new_tokens=4000)\n",
    "\n",
    "  _, _, after = tokenizer.decode(outputs[0]).partition(\"assistant<|end_header_id|>\")\n",
    "  content = after.strip()\n",
    "\n",
    "  return content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eu7Sv3bDhXdI"
   },
   "outputs": [],
   "source": [
    "def query_llm(model_name: str, user_prompt):\n",
    "    try:\n",
    "        model = MODELS[model_name]\n",
    "\n",
    "        if \"gpt\" in model.lower():\n",
    "            response = ask_gpt(model, user_prompt)\n",
    "\n",
    "        elif \"claude\" in model.lower():\n",
    "            response = ask_claude(model, user_prompt)\n",
    "\n",
    "        elif \"gemini\" in model.lower():\n",
    "            response = ask_gemini(model, user_prompt)\n",
    "\n",
    "        elif model in HF_MODELS:\n",
    "            response = ask_hf(model, user_prompt)\n",
    "\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported model. Use one of {', '.join(MODELS.keys())}\")\n",
    "\n",
    "        lines = [line.strip() for line in response.strip().splitlines() if line.strip().startswith(\"{\")]\n",
    "\n",
    "        return [json.loads(line) for line in lines]\n",
    "\n",
    "    except Exception as e:\n",
    "        raise Exception(f\"Model query failed: {str(e)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mxuwLUsVlBlY"
   },
   "source": [
    "## Output Formatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IAKfqgZIlGuP"
   },
   "outputs": [],
   "source": [
    "def save_dataset(records, file_format: str, file_name: str):\n",
    "    df = pd.DataFrame(records)\n",
    "    print(df.shape)\n",
    "    if file_format == \".csv\":\n",
    "        df.to_csv(file_name, index=False)\n",
    "    elif file_format == \".tsv\":\n",
    "        df.to_csv(file_name, sep=\"\\t\", index=False)\n",
    "    elif file_format == \".jsonl\":\n",
    "        with open(file_name, \"w\") as f:\n",
    "            for record in records:\n",
    "                f.write(json.dumps(record) + \"\\n\")\n",
    "    elif file_format == \".json\":\n",
    "        df.to_json(file_name, orient=\"records\", index=False)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported file format\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gkpkQ0nal_5B"
   },
   "outputs": [],
   "source": [
    "def generate_dataset(\n",
    "    model_name: str,\n",
    "    business_case: str,\n",
    "    num_records: int = 100,\n",
    "    schema_text: str = None,\n",
    "    file_format: str = '.jsonl',\n",
    "    file_name: str = 'test_dataset.jsonl'\n",
    "):\n",
    "    \"\"\"\n",
    "    Generates a synthetic dataset using an LLM based on the given business case and optional schema.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[str, pd.DataFrame | None]: A status message and a preview DataFrame (first 10 rows) if successful.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # Validate number of records\n",
    "        if num_records <= 10:\n",
    "            return \"❌ Error: Number of records must be greater than 10.\", None\n",
    "        if num_records > 1000:\n",
    "            return \"❌ Error: Number of records must be less than or equal to 1000.\", None\n",
    "\n",
    "        # Validate file format\n",
    "        if file_format not in FILE_FORMATS:\n",
    "            return f\"❌ Error: Invalid file format '{file_format}'. Supported formats: {FILE_FORMATS}\", None\n",
    "\n",
    "        # Ensure file name has correct extension\n",
    "        if not file_name.endswith(file_format):\n",
    "            file_name += file_format\n",
    "\n",
    "        # Generate the prompt and query the model\n",
    "        prompt = get_user_prompt(business_case, schema_text, num_records)\n",
    "        records = query_llm(model_name, prompt)\n",
    "\n",
    "        if not records:\n",
    "            return \"❌ Error: No valid records were generated by the model.\", None\n",
    "\n",
    "        # Save dataset\n",
    "        save_dataset(records, file_format, file_name)\n",
    "\n",
    "        # Prepare preview\n",
    "        df = pd.DataFrame(records)\n",
    "        preview = df.head(10)\n",
    "\n",
    "        success_message = (\n",
    "            f\"✅ Generated {len(records)} records successfully!\\n\"\n",
    "            f\"📁 Saved to: {file_name}\\n\"\n",
    "        )\n",
    "\n",
    "        return success_message, preview\n",
    "\n",
    "    except Exception as e:\n",
    "        return f\"❌ Error: {str(e)}\", None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 702
    },
    "id": "Z9WdaSfFUakj",
    "outputId": "2fbce2c5-a6d3-4dd8-a9d2-0e38c18d202e"
   },
   "outputs": [],
   "source": [
    "with gr.Blocks(title=\"Synthetic Dataset Generator\", theme=gr.themes.Monochrome()) as interface:\n",
    "    tokenizer = None\n",
    "    inputs = None\n",
    "    hf_model = None\n",
    "    outputs = None\n",
    "\n",
    "    gr.Markdown(\"# Dataset Generator\")\n",
    "    gr.Markdown(\"Generate synthetic datasets using AI models\")\n",
    "\n",
    "    with gr.Row():\n",
    "        with gr.Column(scale=2):\n",
    "            schema_input = gr.Textbox(\n",
    "                label=\"Schema\",\n",
    "                value=DEFAULT_SCHEMA_TEXT,\n",
    "                lines=15,\n",
    "                placeholder=\"Define your dataset schema here... Please follow this format: Name (TYPE) - Description, example: Example\"\n",
    "            )\n",
    "\n",
    "            business_case_input = gr.Textbox(\n",
    "                label=\"Business Case\",\n",
    "                value=\"I want to generate restaurant dataset\",\n",
    "                lines=1,\n",
    "                placeholder=\"Enter business case description...\"\n",
    "            )\n",
    "\n",
    "            with gr.Row():\n",
    "                model_dropdown = gr.Dropdown(\n",
    "                    label=\"Model\",\n",
    "                    choices=list(MODELS.keys()),\n",
    "                    value=list(MODELS.keys())[0],\n",
    "                    interactive=True\n",
    "                )\n",
    "\n",
    "                nr_records_input = gr.Number(\n",
    "                    label=\"Number of records\",\n",
    "                    value=27,\n",
    "                    minimum=11,\n",
    "                    maximum=1000,\n",
    "                    step=1\n",
    "                )\n",
    "\n",
    "            with gr.Row():\n",
    "                filename_input = gr.Textbox(\n",
    "                      label=\"Save as\",\n",
    "                      value=\"restaurant_dataset\",\n",
    "                      placeholder=\"Enter filename (extension will be added automatically)\"\n",
    "                  )\n",
    "\n",
    "                file_format_dropdown = gr.Dropdown(\n",
    "                    label=\"File format\",\n",
    "                    choices=FILE_FORMATS,\n",
    "                    value=FILE_FORMATS[0],\n",
    "                    interactive=True\n",
    "                )\n",
    "\n",
    "            generate_btn = gr.Button(\"🚀 Generate\", variant=\"secondary\", size=\"lg\")\n",
    "\n",
    "        with gr.Column(scale=1):\n",
    "            gr.Markdown(\"\"\"\n",
    "            ### 📝 Dataset Generation Instructions\n",
    "\n",
    "            1. **🗂 Schema** – Define your dataset structure\n",
    "              *(default: restaurant schema provided)*\n",
    "            2. **💡 Business Case** – Enter a prompt to guide the AI for generating data\n",
    "            3. **🤖 Model** – Choose your AI model: GPT, Claude, Gemini, or Hugging Face\n",
    "            4. **📊 Number of Records** – Specify entries to generate\n",
    "              *(min: 11, max: 1000)*\n",
    "            5. **📁 File Format** – Select output type: `.csv`, `.tsv`, `.jsonl`, or `.json`\n",
    "            6. **💾 Save As** – Provide a filename *(extension auto-added)*\n",
    "            7. **🚀 Generate** – Click **Generate** to create your dataset\n",
    "\n",
    "            ### 🔧 Requirements\n",
    "\n",
    "            Set API keys in Colab’s secret section:\n",
    "              `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GOOGLE_API_KEY`, `HF_TOKEN`\n",
    "            \"\"\")\n",
    "            output_status = gr.Textbox(\n",
    "                label=\"Status\",\n",
    "                lines=4,\n",
    "                interactive=False\n",
    "            )\n",
    "\n",
    "            output_preview = gr.Dataframe(\n",
    "                label=\"Preview (first 10 rows)\",\n",
    "                interactive=False,\n",
    "                wrap=True\n",
    "            )\n",
    "\n",
    "    generate_btn.click(\n",
    "        fn=generate_dataset,\n",
    "        inputs=[\n",
    "            model_dropdown,\n",
    "            business_case_input,\n",
    "            nr_records_input,\n",
    "            schema_input,\n",
    "            file_format_dropdown,\n",
    "            filename_input\n",
    "        ],\n",
    "        outputs=[output_status, output_preview]\n",
    "    )\n",
    "\n",
    "interface.launch(debug=True)\n",
    "\n",
    "del tokenizer, inputs, hf_model, outputs\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "w-ewbsjInopm"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
